{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 融合 Transpose 和 Matmul\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "import tvm.testing\n",
    "from tvm import relax\n",
    "from tvm.script import ir as I\n",
    "from tvm.script import relax as R\n",
    "from tvm.script import tir as T\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input module for the fusion demo: `w` is transposed with permute_dims and the\n",
    "# transposed tensor is then used as the right-hand operand of matmul.\n",
    "@I.ir_module\n",
    "class Before:\n",
    "    @R.function\n",
    "    def main(\n",
    "        x: R.Tensor((128, 256), \"float32\"),\n",
    "        w: R.Tensor((128, 256), \"float32\"),\n",
    "    ) -> R.Tensor((128, 128), \"float32\"):\n",
    "        with R.dataflow():\n",
    "            # Swap the two axes of w: (128, 256) -> (256, 128).\n",
    "            wT = R.permute_dims(w, [1, 0])\n",
    "            # (128, 256) @ (256, 128) -> (128, 128), matching the declared return type.\n",
    "            o = R.matmul(x, wT)\n",
    "            R.output(o)\n",
    "        return o"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assemble the two-pass pipeline first, then apply it to the module defined above.\n",
    "fuse_pipeline = tvm.ir.transform.Sequential(\n",
    "    [\n",
    "        relax.transform.FuseTransposeMatmul(),\n",
    "        # FuseTIR is only used here to remove the unused primitive function.\n",
    "        relax.transform.FuseTIR(),\n",
    "    ]\n",
    ")\n",
    "after = fuse_pipeline(Before)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(x: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">128</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), w: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">128</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">128</span>, <span style=\"color: #008000\">128</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            wT: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">256</span>, <span style=\"color: #008000\">128</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>permute_dims(w, axes<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">0</span>])\n",
       "            o: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">128</span>, <span style=\"color: #008000\">128</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>matmul(x, wT, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(o)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> o\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Print the input module as TVMScript: permute_dims and matmul are still separate ops.\n",
    "Before.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func(private<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">True</span>)\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">NT_matmul</span>(x: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">128</span>), T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>)), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), w: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">128</span>), T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>)), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), NT_matmul: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">128</span>), T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">128</span>)), <span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
       "        T<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>bool(<span style=\"color: #008000; font-weight: bold\">True</span>)})\n",
       "        <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
       "        <span style=\"color: #008000; font-weight: bold\">for</span> i0, i1, k <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">128</span>), T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">128</span>), T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>)):\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;NT_matmul&quot;</span>):\n",
       "                v_i0, v_i1, v_k <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>axis<span style=\"color: #AA22FF; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SSR&quot;</span>, [i0, i1, k])\n",
       "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>reads(x[v_i0, v_k], w[v_i1, v_k])\n",
       "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>writes(NT_matmul[v_i0, v_i1])\n",
       "                <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>init():\n",
       "                    NT_matmul[v_i0, v_i1] <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0.0</span>)\n",
       "                NT_matmul[v_i0, v_i1] <span style=\"color: #AA22FF; font-weight: bold\">=</span> NT_matmul[v_i0, v_i1] <span style=\"color: #AA22FF; font-weight: bold\">+</span> x[v_i0, v_k] <span style=\"color: #AA22FF; font-weight: bold\">*</span> w[v_i1, v_k]\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(x: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">128</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), w: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">128</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">128</span>, <span style=\"color: #008000\">128</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        cls <span style=\"color: #AA22FF; font-weight: bold\">=</span> Module\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            gv <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>call_tir(cls<span style=\"color: #AA22FF; font-weight: bold\">.</span>NT_matmul, (x, w), out_sinfo<span style=\"color: #AA22FF; font-weight: bold\">=</span>R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">128</span>, <span style=\"color: #008000\">128</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>))\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Print the transformed module: main now issues a single call_tir to the\n",
    "# NT_matmul prim_func, which reads w as w[v_i1, v_k] instead of materializing wT.\n",
    "after.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 融合 Transpose 和常量 Matmul"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seeded generator so the constant weight baked into the module is identical on\n",
    "# every Restart & Run All (the original draw from np.random.uniform was unseeded).\n",
    "w = relax.const(\n",
    "    np.random.default_rng(0).uniform(-1e-3, 1e-3, (128, 256)), \"float32\"\n",
    ")\n",
    "\n",
    "# Same transpose + matmul pattern as the first example, but the weight is not a\n",
    "# function parameter: `w` is the relax constant defined just above and is\n",
    "# referenced directly from the function body.\n",
    "@I.ir_module\n",
    "class Before:\n",
    "    @R.function\n",
    "    def main(\n",
    "        x: R.Tensor((128, 256), \"float32\"),\n",
    "    ) -> R.Tensor((128, 128), \"float32\"):\n",
    "        with R.dataflow():\n",
    "            # Transpose of the constant weight: (128, 256) -> (256, 128).\n",
    "            wT = R.permute_dims(w, [1, 0])\n",
    "            # (128, 256) @ (256, 128) -> (128, 128), matching the declared return type.\n",
    "            o = R.matmul(x, wT)\n",
    "            R.output(o)\n",
    "        return o"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assemble the two-pass pipeline first, then apply it to the constant-weight module.\n",
    "fuse_pipeline = tvm.ir.transform.Sequential(\n",
    "    [\n",
    "        relax.transform.FuseTransposeMatmul(),\n",
    "        # FuseTIR is only used here to remove the unused primitive function.\n",
    "        relax.transform.FuseTIR(),\n",
    "    ]\n",
    ")\n",
    "after = fuse_pipeline(Before)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(x: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">128</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">128</span>, <span style=\"color: #008000\">128</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            wT: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">256</span>, <span style=\"color: #008000\">128</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>permute_dims(metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">0</span>], axes<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">0</span>])\n",
       "            o: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">128</span>, <span style=\"color: #008000\">128</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>matmul(x, wT, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(o)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> o\n",
       "\n",
       "<span style=\"color: #007979; font-style: italic\"># Metadata omitted. Use show_meta=True in script() method to show it.</span>\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Print the input module as TVMScript: the weight appears as a metadata constant\n",
    "# that is fed to permute_dims before matmul.\n",
    "Before.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func(private<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">True</span>)\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">NT_matmul</span>(x: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">128</span>), T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>)), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">128</span>), T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>)), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), NT_matmul: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">128</span>), T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">128</span>)), <span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
       "        T<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>bool(<span style=\"color: #008000; font-weight: bold\">True</span>)})\n",
       "        <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
       "        <span style=\"color: #008000; font-weight: bold\">for</span> i0, i1, k <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">128</span>), T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">128</span>), T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>)):\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;NT_matmul&quot;</span>):\n",
       "                v_i0, v_i1, v_k <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>axis<span style=\"color: #AA22FF; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SSR&quot;</span>, [i0, i1, k])\n",
       "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>reads(x[v_i0, v_k], B[v_i1, v_k])\n",
       "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>writes(NT_matmul[v_i0, v_i1])\n",
       "                <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>init():\n",
       "                    NT_matmul[v_i0, v_i1] <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0.0</span>)\n",
       "                NT_matmul[v_i0, v_i1] <span style=\"color: #AA22FF; font-weight: bold\">=</span> NT_matmul[v_i0, v_i1] <span style=\"color: #AA22FF; font-weight: bold\">+</span> x[v_i0, v_k] <span style=\"color: #AA22FF; font-weight: bold\">*</span> B[v_i1, v_k]\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(x: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">128</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">128</span>, <span style=\"color: #008000\">128</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        cls <span style=\"color: #AA22FF; font-weight: bold\">=</span> Module\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            gv <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>call_tir(cls<span style=\"color: #AA22FF; font-weight: bold\">.</span>NT_matmul, (x, metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">0</span>]), out_sinfo<span style=\"color: #AA22FF; font-weight: bold\">=</span>R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">128</span>, <span style=\"color: #008000\">128</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>))\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "\n",
       "<span style=\"color: #007979; font-style: italic\"># Metadata omitted. Use show_meta=True in script() method to show it.</span>\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Print the transformed module: the metadata constant is passed straight to the\n",
    "# NT_matmul prim_func through call_tir, with no separate permute_dims step.\n",
    "after.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
